from _fitting import fitting
from pyrap.functionals import *
import numpy as NUM

class fitserver(object):
    """Create a `fitserver instance. The object can be created without
    arguments (in which case it is assumed to be a real fitter), or with
    the arguments specifying the number of unknowns to be solved for (a number
    not relevant in practice); and the type of solution: real, complex,
    conjugate (complex with both the unknown and its conjugate in the
    condition equations), separable complex, asreal complex with the real and
    imaginary part seen as independent unknowns. All solutions need a model
    (specified as a :mod:`pyrap.functionals`. All solutions are done using an
    SVD type method. A collinearity factor can be specified, which is in
    essence the sine squared of the minimum angle between two normal equation
    columns that are still to be considered independent. For automatic
    non-linear solutions, a Levenberg-Marquardt factor (see
    `Note 224 <../../casacore/doc/notes/224.html>`_) is used, which can be
    specified as well.

    In the case of non-linear solutions that have to be handled by the system,
    an initial estimate for the model parameters is necessary.

    :param n:      number of unknowns
    :param ftype:  type of solution
                   Allowed: real, complex, separable, asreal, conjugate
    :param colfac: collinearity factor
    :param lmfac:  Levenberg-Marquardt factor
    :param fid:    the id of a sub-fitter

    """

    def __init__(self, n=0, m=1, ftype=0, colfac=1.0e-8, lmfac=1.0e-3):
        self._fitids = []
        self._typeids = {"real": 0, "complex": 1, "separable": 3,
                         "asreal": 7, "conjugate": 11}
        self._fitproxy = fitting()
        fid = self.fitter(n=n, ftype=ftype, colfac=colfac, lmfac=lmfac)
        if fid != 0:
            raise RuntimeError("System problem creating fitter server")

    def fitter(self, n=0, ftype="real", colfac=1.0e-8, lmfac=1.0e-3):
        """Create a sub-fitter (which can be used in the same way as a fitter
        default fitter). This function returns an identification, which has to
        be used in the `fid` argument of subsequent calls. The call can
        specify the standard constructor arguments (`n`, `type`, `colfac`,
        `lmfac`), or can specify them later in a :meth:`set` statement.

        :param n:      number of unknowns
        :param ftype:  type of solution
                       Allowed: real, complex, separable, asreal, conjugate
        :param colfac: collinearity factor
        :param lmfac:  Levenberg-Marquardt factor
        :param fid:    the id of a sub-fitter

        """
        fid = self._fitproxy.getid()
        ftype = self._gettype(ftype)
        n = len(self._fitids)
        if 0 <= fid < n:
            self._fitids[fid] = {}
        elif fid == n:
            self._fitids.append({})
        else:
            # shouldn't happen
            raise RangeError("fit id out of range")
        self.init(n=n, ftype=ftype, colfac=colfac, lmfac=lmfac, fid=fid)
        return fid

    def init(self,  n=0, ftype="real", colfac=1.0e-8, lmfac=1.0e-3, fid=0):
        """Set selected properties of the fitserver instance. Like in the
        constructor, the number of unknowns to be solved for; the number of
        simultaneous solutions; the ftype and the collinearity and
        Levenberg-Marquardt factor can be specified. Individual values can be
        overwritten with the :meth:`set` function.

        :param n: number of unknowns
        :param ftype: type of solution
                      Allowed: real, complex, separable, asreal, conjugate
        :param colfac:	collinearity factor
        :param lmfac: Levenberg-Marquardt factor
        :param fid: the id of a sub-fitter
        """
        ftype = self._gettype(ftype)
        self._fitids[fid]["stat"] = False
        self._fitids[fid]["solved"] = False
        self._fitids[fid]["haserr"] = False
        self._fitids[fid]["fit"] = False
        self._fitids[fid]["looped"] = False
        if self._fitproxy.init(fid, n, ftype, colfac, lmfac):
            self._fitids[fid]["stat"] = self._getstate(fid)
        else:
            return False

    def _gettype(self, ftype):
        if isinstance(ftype, str):
            ftype = ftype.lower()
            if not self._typeids.has_key(ftype):
                raise TypeError("Illegal fitting type")
            else:
                return self._typeids[ftype]
        elif isinstance(ftype, int):
            if not ftype in self._typeids.values():
                raise TypeError("Illegal fitting type")
        else:
            raise TypeError("Illegal fitting type")
        return ftype

    def _settype(self, ftype=0):
        for k, v in self._typeids.iteritems():
            if ftype == v:
                return k
        return "real"

    def _checkid(self, fid=0):
        if not (0 <= fid < len(self._fitids)
                and isinstance(self._fitids[fid], dict)
                and self._fitids[fid].has_key("stat")
                and isinstance(self._fitids[fid]["stat"], dict)):
            raise ValueError("fit id out of range")

    def _reshape(self, fid=0):
        pass

    def  _getstate(self, fid):
        d = self._fitproxy.getstate(fid)
        if d.has_key("typ"):
            d["typ"] = self._settype(d["typ"])
        return d

    def set(self, n=None, ftype=None, colfac=None, lmfac=None, fid=0):
        """Set selected properties of the fitserver instance. All unset
        properties remain the same (in the :meth:`init` method all properties
        are (re-)initialized). Like in the constructor, the number of unknowns
        to be solved for; the number of simultaneous solutions; the ftype (as
        code); and the collinearity and Levenberg-Marquardt factor can be
        specified.

        :param n: number of unknowns
        :param ftype: type of solution
                    Allowed: real, complex, separable, asreal, conjugate
        :param colfac:	collinearity factor
        :param lmfac: Levenberg-Marquardt factor
        :param fid: the id of a sub-fitter
        """
        self._checkid(fid)
        if ftype is None:
            ftype = -1
        else:
            ftype = self._gettype(ftype)
        if  n is None:
            n = -1
        elif n < 0:
            raise ValueError("Illegal set argument n")
        if  colfac is None:
            colfac = -1
        elif colfac < 0:
            raise ValueError("Illegal set argument colfac")
        if  lmfac is None:
            lmfac = -1
        elif lmfac < 0:
            raise ValueError("Illegal set argument lmfac")

        self._fitids[fid]["stat"] = False
        self._fitids[fid]["solved"] = False
        self._fitids[fid]["haserr"] = False
        self._fitids[fid]["fit"] = True
        self._fitids[fid]["looped"] = False
        if n != -1 or ftype != -1 or colfac != -1 or lmfac != -1:
            if not self._fitproxy.set(fid, n, ftype, colfac, lmfac):
                return False
        self._fitids[fid]["stat"] = self._getstate(fid)
        return True


    def done(self, fid=0):
        self._checkid(fid)
        self._fitids[fid] = {}
        self._fitproxy.done(fid)

    def reset(self, fid=0):
        """Reset the object's resources to its initialized state.

        :param fid: the id of a sub-fitter
        """
        self._checkid(fid)
        self._fitids[fid]["solved"] = False
        self._fitids[fid]["haserr"] = False
        if not self._fitids[fid]["looped"]:
            return self._fitproxy.reset(fid)
        else:
            self._fitids[fid]["looped"] = False
        return True

    def getstate(self, fid=0):
        """Obtain the state of the fitter object or a sub-fitter.

        :param fid: the id of a sub-fitter
        """
        self._checkid(fid)
        return self._fitids[fid]["stat"]

    def clearconstraints(self, fid=0):
        self._checkid(fid)
        self._fitids[fid]["constraint"] = {}

    def addconstraint(self, x, y=0, fnct=None, fid=0):
        self._checkid(fid)
        i = 0
        if self._fitids[fid].has_key("constraint"):
            i = len(self._fitids[fid]["constraint"])
        else:
            self._fitids[fid]["constraint"] = {}
        # dict key needs to be string
        i = str(i)
        self._fitids[fid]["constraint"][i] = {}
        if isinstance(fnct, functional):
            self._fitids[fid]["constraint"][i]["fnct"] = fnct.todict()
        else:
            self._fitids[fid]["constraint"][i]["fnct"] = \
                functional("hyper", len(x)).todict()
        self._fitids[fid]["constraint"][i]["x"] = [float(v) for v in x]
        self._fitids[fid]["constraint"][i]["y"] = float(y)
        print self._fitids[fid]["constraint"]

    def fitpoly(self, n, x, y, sd=None, wt=1.0, fid=0):
        if self.set(n=n+1, fid=fid):
            return self.linear(poly(n), x, y, sd, wt, fid)

    def fitspoly(self, n, x, y, sd=None, wt=1.0, fid=0):
        """Create normal equations from the specified condition equations, and
        solve the resulting normal equations. It is in essence a combination

        The method expects that the properties of the fitter to be used have
        been initialized or set (like the number of simultaneous solutions m;
        the type; factors). The main reason is to limit the number of
        parameters on the one hand, and on the other hand not to depend
        on the actual array structure to get the variables and type. Before
        fitting the x-range is normalized to values less than 1 to cater for
        large difference in x raised to large powers. Later a shift to make x
        around zero will be added as well.

        :param n: the order of the polynomial to solve for
        :param x: the abscissa values
        :param y: the ordinate values
        :param sd: standard deviation of equations (one or more values used
                   cyclically)
        :param wt: an optional alternate for `sd`
        :param fid: the id of the sub-fitter (numerical)

        Example::

            fit = fitserver()
            x = N.arange(1,11) # we have values at 10 'x' values
            y = 2. + 0.5*x - 0.1*x**2 # which are 2 +0.5x -0.1x^2
            fit.fitspoly(3, x, y) # fit a 3-degree polynomial
            print fit.solution(), fit.error() #  show solution and their errors

        """
        a = max(abs(max(x)), abs(min(x)))
        if a == 0: a = 1
        a = 1.0/a
        b = NUM.power(a, range(n+1))
        if self.set(n=n+1, fid=fid):
            return self.linear(poly(n), x, y, sd, wt, fid)
        if self.set(n=n+1, fid=fid):
            self.linear(poly(n), x*a, y, sd, wt, fid)
            self._fitids[fid]["sol"] *= b
            self._fitids[fid]["error"] *= b

    def fitavg(self, y, sd=None, wt=1.0, fid=0):
        if self.set(n=1, fid=fid):
            return self.linear(compiled("p"), [], y, sd, wt, fid)

    def _fit(self, **kw):
        fitfunc = kw.pop("fitfunc")
        sd = kw.pop("sd")
        fid = kw.pop("fid")
        kw["id"] = fid
        if not isinstance(kw["fnct"], functional):
            raise TypeError("No or illegal functional")
        if not self.set(n=kw["fnct"].npar(), fid=fid):
            raise ValueError("Illegal fit id")
        fnct = kw["fnct"]
        kw["fnct"] = fnct.todict()
        self.reset(fid)
        x = self._as_array(kw["x"])
        if x.ndim > 1 and fnct.ndim() == x.ndim:
            x = x.flatten()
        y = self._as_array(kw["y"])
        if y.ndim > 1 and fnct.ndim() == y.ndim:
            x = y.flatten()
        wt = self._as_array(kw["wt"])
        if sd is not None:
            sd = self._as_array(sd)
            wt = sd.copy()
            wt[sd == 0] = 1
            wt = 1/abs(wt * NUM.conjugate(wt))
            wt[NUM.logical_or(sd == -1, sd == 0)] = 0
        ftype = fitfunc
        dtype = 'float'
        if (self.getstate(fid)["typ"] != "real"
            or NUM.iscomplexobj(x) \
            or NUM.iscomplexobj(y) \
            or NUM.iscomplexobj(wt) ):
            ftype = "cx%s" % fitfunc
            dtype = 'complex'
        kw["x"] = self._as_array(x, dtype)
        kw["y"] = self._as_array(y, dtype)
        kw["wt"] = self._as_array(wt, dtype)
        if not self._fitids[fid].has_key("constraint"):
            self._fitids[fid]["constraint"] = {}
        kw["constraint"] =  self._fitids[fid]["constraint"]
        func = getattr(self._fitproxy, ftype)
        result = func(**kw)
        self._fitids[fid].update(result)
        self._fitids[fid]["solved"] = True
        self._fitids[fid]["haserr"] = True
        self._fitids[fid]["looped"] = False

    def functional(self, fnct, x, y, sd=None, wt=1.0, mxit=50, fid=0):
        """Ths will make a non-linear least squares solution for the points
        through the ordinates at the abscissa values, using the specified
        `fnct`. Details can be found in the :meth:`linear` description.

        :param fnct: the functional to fit
        :param x: the abscissa values
        :param y: the ordinate values
        :param sd: standard deviation of equations (one or more values used
                   cyclically)
        :param wt: an optional alternate for `sd`
        :param mxit: the maximum number of iterations
        :param fid: the id of the sub-fitter (numerical)

        """
        self._fit(fitfunc="functional", fnct=fnct, x=x, y=y, sd=sd, wt=wt,
                  mxit=mxit, fid=fid)
    nonlinear = functional

    def linear(self, fnct, x, y, sd=None, wt=1.0, fid=0):
        """Makes a linear least squares solution for the points through the
        ordinates at the x values, using the specified fnct. The x can be of
        any dimension, depending on the number of arguments needed in the
        functional evaluation. The values should be given in the order:
        x0[1], x0[2], ..., x1[1], ..., xn[m] if there are n observations,
        and m arguments. x should be a vector of m*n length; y (the
        observations) a vector of length n.

        :param fnct: the functional to fit
        :param x: the abscissa values
        :param y: the ordinate values
        :param sd: standard deviation of equations (one or more values used
                   cyclically)
        :param wt: an optional alternate for `sd`
        :param fid: the id of the sub-fitter (numerical)

        Example::

            #

        """
        self._fit(fitfunc="linear", fnct=fnct, x=x, y=y, sd=sd, wt=wt, fid=fid)

    def _getval(self, valname, fid):
        self._checkid(fid)
        if not self._fitids[fid]["solved"]:
            raise RuntimeError("Not solved yet")
        return self._fitids[fid][valname]

    def solution(self, fid=0):
        """Return the solution for the fit.

        :param fid: the id of the sub-fitter (numerical)

        """
        return self._getval("sol", fid)

    def rank(self, fid=0):
        """Obtain the rank (in SVD sense) of a fit. The :meth:`constraint`
        method will show the equations that are orthogonal to the existing
        ones, and which will make the solution possible.

        :param fid: the id of the sub-fitter (numerical)

        """
        return self._getval("rank", fid)

    def deficiency(self, fid=0):
        """Obtain the missing rank (in SVD sense) of a fit. The
        :meth:`constraint` method will show the equations that are orthogonal
        to the existing ones, and which will make the solution possible.

        :param fid: the id of the sub-fitter (numerical)

        """
        return self._getval("deficiency", fid)

    def chi2(self, fid=0):
        """Obtain the chi squared of a fit.

        :param fid: the id of the sub-fitter (numerical)

        """
        return self._getval("chi2", fid)

    def sd(self, fid=0):
        """Obtain the standard deviation per unit of weight of a fit.

        :param fid: the id of the sub-fitter (numerical)

        """
        return self._getval("sd", fid)

    def mu(self, fid=0):
        """Obtain the standard deviation per condition equation of a fit.

        :param fid: the id of the sub-fitter (numerical)

        """
        return self._getval("mu", fid)

    stddev = mu

    def covariance(self, fid=0):
        """Obtain the covariance matrix of a fit.

        :param fid: the id of the sub-fitter (numerical)

        """
        return self._getval("covar", fid)

    def error(self, fid=0):
        """Obtain the errors in the unknowns of a fit.

        :param fid: the id of the sub-fitter (numerical)

        """
        return self._getval("error", fid)

    def constraint(self, n=-1, fid=0):
        """Obtain the set of orthogonal equations that make the solution of
        the rank deficient normal equations possible.

        :param fid: the id of the sub-fitter (numerical)

        """
        c = self._getval("constr", fid)
        if n < 0 or n > self.deficiency(fid):
            return c
        else:
            raise RuntimeError("Not yet implemented")

    def fitted(self, fid=0):
        """Test if enough Levenberg-Marquardt loops have been done. It returns
        True if no improvement possible.

        :param fid: the id of the sub-fitter (numerical)
        """
        self._checkid(fid)
        return not (self._fitids[fid]["fit"] > 0
                    or self._fitids[fid]["fit"] < -0.001)

    def _as_array(self, v, dtype=None):
        if not hasattr(v, "__len__"):
            v = [v]
        return NUM.asarray(v, dtype)
